Fix several bugs for enabling Paddle to train with CINN. #36739
Thanks for your contribution!
set_cinn_flag(False)


if __name__ == '__main__':
    unittest.main()
    do_test(tempfile.mkdtemp(prefix="dots_"))
Does this kind of debug information still need to be kept?
Removed.
@@ -136,6 +153,18 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
    }
    old_var2new_var[var] = sub_node;
  }
  for (auto* var : cluster_inputs) {
    if (var->Var()) {
      auto* sub_node = subgraph->CreateVarNode(var->Var());
Why not call CreateEmptyNode here when var->Var() == nullptr?
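Because the subgraph does not need nodes whose var_desc is null, and CINN does not use them at compile time either. These empty vars exist only to add dependency relationships, and those are already reflected in the modified main graph.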
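To illustrate the reply above, here is a minimal Python sketch (not the actual C++ in `CreateNewSubGraph`) of copying cluster inputs into a subgraph while skipping control-dependency variables. `VarNode` and `copy_cluster_inputs` are hypothetical names, and `var_desc=None` is assumed to stand in for `var->Var() == nullptr`.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Optional


@dataclass(frozen=True)
class VarNode:
    """Toy stand-in for a graph variable node (hypothetical, not Paddle's Node)."""
    name: str
    var_desc: Optional[str]  # None models a control var with no VarDesc


def copy_cluster_inputs(cluster_inputs: Iterable[VarNode]) -> Dict[VarNode, VarNode]:
    """Copy only inputs that carry a var_desc into the subgraph.

    Control vars (var_desc is None) are skipped: they only encode
    dependencies, which are assumed to live on the outer graph.
    """
    old_to_new: Dict[VarNode, VarNode] = {}
    for var in cluster_inputs:
        if var.var_desc is not None:
            # Create a fresh node in the subgraph for this input.
            old_to_new[var] = VarNode(var.name, var.var_desc)
    return old_to_new
```

With inputs `[VarNode("x", "desc_x"), VarNode("ctrl", None)]`, only `x` gets a counterpart in the returned mapping; the control var is left to the outer graph.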

    def test_run_with_cinn(self):
        do_test(self.tmpdir)
        set_cinn_flag(False)
The set_cinn_flag(False) here should be removed; otherwise none of the later tests will run.
Update.
exe.run(startup_program)

build_strategy = paddle.static.BuildStrategy()
build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, "viz")
There's no need to keep this debug information, is there?
Let's keep the debug information. This class takes care of the saved files and deletes them on exit.
@@ -130,12 +147,25 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
  for (auto* var : cluster_internals) {
    Node* sub_node;
    if (var->Var() == nullptr) {
      // TODO(wzzju): If this case occurs, there maybe bugs when using CINN.
Does this bug still exist? If it does, wouldn't it be better to add a PADDLE_ENFORCE?
Done.
LGTM
…e#36739)
* Update the content of `test_parallel_executor_run_cinn.py`.
* Fix some bugs in the topological sort and `CreateNewSubGraph`.
* Update the CINN commit id used by Paddle.
* Update the unit test to `add+relu`.
* Update according to reviewers' suggestion.
PR types
Bug fixes
PR changes
Others
Describe
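#36739 (this PR) and #36698 both address problems encountered while getting Paddle training with CINN to work end to end. All of the problems are listed below:
* The subgraphs that `build_cinn_pass` finds and creates may contain ControlVars whose VarDesc is NULL, and this corner case needs special handling. For the approach, see "Fix the null ptr bug in build_cinn_pass." #36698.
* Adding `build_cinn_pass` at the end of ParallelExecutorPassBuilder conflicts with the GPU memory optimization passes that follow it. To avoid conflicts with other passes, `build_cinn_pass` is moved up to run as the first pass.
* After `build_cinn_pass` creates a CinnLaunchOp, it must set the op's `op_role` attribute; otherwise the passes that run after it fail with an error that the attribute cannot be found.
* Fix the subgraph-creation logic of `build_cinn_pass` to handle the case where an output node also serves as an input inside the subgraph.
* Fix the connections between nodes in `build_cinn_pass`, as well as the logic for deleting unused nodes.
* Change the unit test to an `add + relu` model with fixed values (to make precision alignment easier), and add the information needed for debugging.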
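The item above about output nodes that also act as inputs inside the subgraph can be pictured with a toy classification of a cluster's variables. This is a simplified sketch, not the algorithm in `build_cinn_pass`; `Op`, `classify_cluster_vars`, and the rule that "a variable produced in the cluster is an output if something outside consumes it, or nothing inside does" are all assumptions made for illustration.

```python
from typing import Iterable, NamedTuple, Set, Tuple


class Op(NamedTuple):
    """Toy operator: a name plus the variable names it reads and writes."""
    name: str
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


def classify_cluster_vars(
    ops: Iterable[Op], cluster: Set[str]
) -> Tuple[Set[str], Set[str], Set[str], Set[str]]:
    """Return (inputs, outputs, internals, outputs_reused_inside) for a cluster."""
    ops = list(ops)
    inside = [op for op in ops if op.name in cluster]
    outside = [op for op in ops if op.name not in cluster]

    produced_inside = {v for op in inside for v in op.outputs}
    consumed_inside = {v for op in inside for v in op.inputs}
    consumed_outside = {v for op in outside for v in op.inputs}

    # Inputs: read by the cluster but produced elsewhere.
    inputs = consumed_inside - produced_inside
    # Outputs: produced here and needed outside, or never consumed inside.
    outputs = {
        v for v in produced_inside
        if v in consumed_outside or v not in consumed_inside
    }
    # Internals: produced and consumed purely inside the cluster.
    internals = produced_inside - outputs
    # The corner case: an output that is also read by an op in the cluster.
    outputs_reused_inside = outputs & consumed_inside
    return inputs, outputs, internals, outputs_reused_inside
```

For a cluster `{"add", "relu"}` where `add` writes `t`, `relu` reads `t`, and an outside op also reads `t`, the variable `t` lands in both `outputs` and `outputs_reused_inside`, so the subgraph has to wire it to `relu` internally as well as expose it.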